import euclideanDistance from "./euclidean_distance.js";
import makeMatrix from "./make_matrix.js";
import sample from "./sample.js";

/**
 * @typedef {Object} kMeansReturn
 * @property {Array<number>} labels 数据点所属簇的标签数组
 * @property {Array<Array<number>>} centroids 聚类中心坐标数组
 */

/**
 * 执行k均值聚类算法
 *
 * @param {Array<Array<number>>} points 待聚类点的N维坐标数组
 * @param {number} numCluster 需要创建的聚类数量
 * @param {Function} randomSource 可选随机源，生成[0,1)区间的均匀分布值
 * @return {kMeansReturn} 包含标签数组和聚类中心数组的对象
 * @throws {Error} 当存在无关联数据点的聚类中心时抛出异常
 *
 * @example
 * kMeansCluster([[0.0, 0.5], [1.0, 0.5]], 2); // => {labels: [0, 1], centroids: [[0.0, 0.5], [1.0 0.5]]}
 */
function kMeansCluster(points, numCluster, randomSource = Math.random) {
    let oldCentroids = null;
    let newCentroids = sample(points, numCluster, randomSource);
    let labels = null;
    let change = Number.MAX_VALUE;
    while (change !== 0) {
        labels = labelPoints(points, newCentroids);
        oldCentroids = newCentroids;
        newCentroids = calculateCentroids(points, labels, numCluster);
        change = calculateChange(newCentroids, oldCentroids);
    }
    return {
        labels: labels,
        centroids: newCentroids
    };
}

/**
 * 根据当前聚类中心为数据点分配簇标签
 *
 * @private
 * @param {Array<Array<number>>} points 数据点坐标数组
 * @param {Array<Array<number>>} centroids 当前聚类中心坐标数组
 * @return {Array<number>} 数据点对应的簇标签数组
 */
function labelPoints(points, centroids) {
    return points.map((p) => {
        let minDist = Number.MAX_VALUE;
        let label = -1;
        for (let i = 0; i < centroids.length; i++) {
            const dist = euclideanDistance(p, centroids[i]);
            if (dist < minDist) {
                minDist = dist;
                label = i;
            }
        }
        return label;
    });
}

/**
 * 根据数据点标签计算新的聚类中心
 *
 * @private
 * @param {Array<Array<number>>} points 数据点坐标数组
 * @param {Array<number>} labels 数据点簇标签数组
 * @param {number} numCluster 聚类总数
 * @return {Array<Array<number>>} 计算得到的新聚类中心数组
 * @throws {Error} 当存在无关联数据点的聚类中心时抛出异常
 */
function calculateCentroids(points, labels, numCluster) {
    // 初始化累加器和计数器
    const dimension = points[0].length;
    const centroids = makeMatrix(numCluster, dimension);
    const counts = Array(numCluster).fill(0);

    // 累加各簇数据点坐标并计数
    const numPoints = points.length;
    for (let i = 0; i < numPoints; i++) {
        const point = points[i];
        const label = labels[i];
        const current = centroids[label];
        for (let j = 0; j < dimension; j++) {
            current[j] += point[j];
        }
        counts[label] += 1;
    }

    // 计算均值并检查空簇
    for (let i = 0; i < numCluster; i++) {
        if (counts[i] === 0) {
            throw new Error(`聚类中心${i}无关联数据点`);
        }
        const centroid = centroids[i];
        for (let j = 0; j < dimension; j++) {
            centroid[j] /= counts[i];
        }
    }

    return centroids;
}

/**
 * 计算新旧聚类中心的总变化量
 *
 * @private
 * @param {Array<Array<number>>} left 新聚类中心数组
 * @param {Array<Array<number>>} right 旧聚类中心数组
 * @return {number} 聚类中心坐标变化总和
 */
function calculateChange(left, right) {
    let total = 0;
    for (let i = 0; i < left.length; i++) {
        total += euclideanDistance(left[i], right[i]);
    }
    return total;
}

export default kMeansCluster;
